import numpy as np
import torch
import torch.nn as nn
from model.satcl_modules import Linear
from copy import deepcopy
import torch.nn.functional as F


def compute_conv_output_size(Lin,
                             kernel_size,
                             stride=1,
                             padding=0,
                             dilation=1):
    """Return the output length of a convolution along one axis.

    Evaluates
    ``floor((Lin + 2*padding - dilation*(kernel_size - 1) - 1) / stride + 1)``
    and converts the result to a Python ``int``.

    Args:
        Lin: Input length along the axis.
        kernel_size: Kernel extent along the axis.
        stride: Step between consecutive kernel positions.
        padding: Padding added to each side of the input.
        dilation: Spacing between kernel taps.
    """
    # Span left over once the (dilated) kernel has been placed at position 0.
    remaining = Lin + 2 * padding - dilation * (kernel_size - 1) - 1
    positions = remaining / float(stride) + 1
    return int(np.floor(positions))


def init_weights(m):
    """Re-initialise ``m.weight`` in place with Kaiming-uniform values.

    Only modules whose exact type is ``nn.Linear``, ``nn.Conv2d`` or the
    project's ``Linear`` are touched; anything else (subclasses included)
    is returned untouched.  Only ``m.weight`` is passed to the initialiser,
    using ``mode='fan_in'`` and ``nonlinearity='relu'``.
    """
    layer_type = type(m)
    # Exact-type comparison on purpose: subclasses are deliberately skipped.
    if not (layer_type == nn.Linear or layer_type == nn.Conv2d
            or layer_type == Linear):
        return
    nn.init.kaiming_uniform_(m.weight, mode='fan_in', nonlinearity='relu')


def adjust_learning_rate(optimizer, epoch, args):
    """Reset or decay the learning rate of every parameter group in place.

    When ``epoch == 1`` each group's ``'lr'`` is set to ``args.lr``.  For any
    other epoch the group's current ``'lr'`` is divided by
    ``args.lr_factor``.

    Args:
        optimizer: Object exposing ``param_groups``, an iterable of dicts
            with an ``'lr'`` entry.
        epoch: Current epoch number.
        args: Object providing ``lr`` and ``lr_factor``.
    """
    groups = optimizer.param_groups
    if epoch == 1:
        for group in groups:
            group['lr'] = args.lr
        return
    for group in groups:
        group['lr'] /= args.lr_factor

def get_model(model):
    """Return a deep copy of ``model.state_dict()``.

    The copy is independent of the dictionary the model hands out, so later
    changes to one are not visible through the other.
    """
    snapshot = model.state_dict()
    return deepcopy(snapshot)


def set_model_(model, state_dict):
    """Load a deep copy of ``state_dict`` into ``model``.

    ``model.load_state_dict`` receives a copy rather than the caller's own
    object.  Whatever ``load_state_dict`` returns is discarded; this function
    always returns ``None``.
    """
    private_copy = deepcopy(state_dict)
    model.load_state_dict(private_copy)


def get_representation_matrix_mlp(task_id, net, device, x, y, old_task_distribution):
    """Forward a label-ordered sample through ``net`` and return per-layer matrices.

    Thirty sampling rounds are run.  Each round scans forward through ``y``
    and picks one index per unique label; the function asserts that exactly
    ten indices are picked per round.  The selected rows of ``x`` are
    flattened to ``28 * 28`` features, passed through ``net`` after
    ``net.eval()`` (the network is left in eval mode), and every tensor found
    in ``net.act`` afterwards is converted to a NumPy matrix.

    Args:
        task_id: Forwarded to ``net`` and used as the first index into
            ``old_task_distribution``.
        net: Callable as ``net(data, task_id, None, -1)`` that exposes
            ``eval()`` and a mapping ``act`` of tensors (presumably filled
            during the forward pass -- confirm against the model class).
        device: Device the sampled inputs are moved to.
        x: Input tensor, indexed along dim 0.
        y: Label tensor scanned to choose the sample indices.
        old_task_distribution: ``old_task_distribution[task_id][layer]`` must
            support ``append``; a flattened copy of each layer matrix is
            appended to it.

    Returns:
        A list with one array per key of ``net.act``: the first 300 rows of
        that activation, transposed.
    """
    num_rounds = 30
    flat_features = 28 * 28

    # The permutation itself is never read; the shuffle is kept because it
    # advances NumPy's global RNG state.
    shuffled = np.arange(x.size(0))
    np.random.shuffle(shuffled)

    labels = torch.unique(y)
    cursor = 0  # shared across rounds: the scan through y never rewinds
    sampled = []
    for _ in range(num_rounds):
        picked = []
        for label in labels:
            while y[cursor] != label:
                cursor += 1
            picked.append(cursor)
        assert len(picked) == 10
        sampled.append(x[picked].view(-1, flat_features).to(device))

    batch = torch.cat(sampled, dim=0).squeeze(1)
    batch = batch.view(-1, flat_features)
    net.eval()
    # The return value is ignored; only net.act is consulted below.
    net(batch, task_id, None, -1)

    rows_per_layer = [300, 300, 300]
    mat_list = []  # one representation matrix per recorded layer
    for layer, key in enumerate(list(net.act.keys())):
        keep = rows_per_layer[layer]
        recorded = net.act[key].detach().cpu().numpy()
        layer_mat = recorded[0:keep].transpose()
        mat_list.append(layer_mat)
        old_task_distribution[task_id][layer].append(
            deepcopy(layer_mat.flatten()))

    rule = '-' * 30
    print(rule)
    print('Representation Matrix')
    print(rule)
    for number, layer_mat in enumerate(mat_list, start=1):
        print('Layer {} : {}'.format(number, layer_mat.shape))
    print(rule)
    return mat_list

def get_representation_matrix_resnet18(task_id, net, device, x, y, old_task_distribution):
    """Build patch-based representation matrices from ``net``'s conv activations.

    The network is put in eval mode (and left there), 100 randomly chosen
    rows of ``x`` are forwarded through it, and 17 activation tensors are
    read back: ``net.act['conv_in']`` plus ``act['conv_0']`` / ``act['conv_1']``
    of the two blocks in each of ``layer1`` .. ``layer4``.  For every one of
    them a matrix is built whose columns are flattened 3x3 patches (all
    channels) cut from the zero-padded activation at the layer's stride.
    For indices 5, 9 and 13 an additional matrix of 1x1 strided patches is
    built from the unpadded activation (the "shortcut" matrices).

    Args:
        task_id: Forwarded to ``net`` and used as the first index into
            ``old_task_distribution``.
        net: Network exposing ``eval()``, ``act``, and ``layer1`` .. ``layer4``
            whose first two entries each expose ``act``.
        device: Device the index tensor and the sampled inputs are moved to.
        x: Input tensor, indexed along dim 0.
        y: Not used by this function.
        old_task_distribution: ``old_task_distribution[task_id][i]`` must
            support ``append``; a flattened copy of the i-th returned matrix
            is appended to it.

    Returns:
        A list of 20 arrays: the 17 patch matrices in order, with the three
        shortcut matrices inserted directly after positions 6, 10 and 14.
    """
    # Collect activations by forward pass
    net.eval()
    # Random permutation of the sample indices (uses NumPy's global RNG).
    r = np.arange(x.size(0))
    np.random.shuffle(r)
    r = torch.LongTensor(r).to(device)
    b = r[0:100]  # ns=100 examples
    example_data = x[b]
    example_data = example_data.to(device)
    # The output is not used; the activations are read from the act dicts below.
    example_out = net(example_data, task_id, None, -1)

    # 17 tensors: the input conv followed by conv_0/conv_1 of each block.
    act_list = []
    act_list.extend([net.act['conv_in'],
                     net.layer1[0].act['conv_0'], net.layer1[0].act['conv_1'], net.layer1[1].act['conv_0'], net.layer1[1].act['conv_1'],
                     net.layer2[0].act['conv_0'], net.layer2[0].act['conv_1'], net.layer2[1].act['conv_0'], net.layer2[1].act['conv_1'],
                     net.layer3[0].act['conv_0'], net.layer3[0].act['conv_1'], net.layer3[1].act['conv_0'], net.layer3[1].act['conv_1'],
                     net.layer4[0].act['conv_0'], net.layer4[0].act['conv_1'], net.layer4[1].act['conv_0'], net.layer4[1].act['conv_1']])

    # Number of samples (out of the 100 forwarded) used for each of the 17 layers.
    batch_list = [10, 10, 10, 10, 10, 10, 10, 10, 50,
                  50, 50, 100, 100, 100, 100, 100, 100]  # scaled
    # network arch
    # The three tables below are indexed in lockstep with act_list.
    # NOTE(review): map_list / in_channel look like the spatial size and channel
    # count of each recorded activation for 84x84 inputs -- confirm against the model.
    stride_list = [2, 1, 1, 1, 1, 2, 1, 1, 1, 2, 1, 1, 1, 2, 1, 1, 1]
    map_list = [84, 42, 42, 42, 42, 42, 21,
                21, 21, 21, 11, 11, 11, 11, 6, 6, 6]
    in_channel = [3, 20, 20, 20, 20, 20, 40, 40,
                  40, 40, 80, 80, 80, 80, 160, 160, 160]

    pad = 1
    # Indices for which an extra 1x1 (shortcut) patch matrix is built.
    sc_list = [5, 9, 13]
    # F.pad spec: one zero on each side of the last two dims; matches pad above.
    p1d = (1, 1, 1, 1)
    mat_final = []  # list containing GPM Matrices
    mat_list = []
    mat_sc_list = []
    for i in range(len(stride_list)):
        # Both branches assign 3: every layer uses a 3x3 kernel here.
        if i == 0:
            ksz = 3
        else:
            ksz = 3
        bsz = batch_list[i]
        st = stride_list[i]
        k = 0  # running column index into mat
        s = compute_conv_output_size(map_list[i], ksz, stride_list[i], pad)
        # One column per (sample, output row, output col); one row per patch element.
        mat = np.zeros((ksz*ksz*in_channel[i], s*s*bsz))
        act = F.pad(act_list[i], p1d, "constant", 0).detach().cpu().numpy()
        for kk in range(bsz):
            for ii in range(s):
                for jj in range(s):
                    # ksz x ksz window over all channels, top-left at (st*ii, st*jj).
                    mat[:, k] = act[kk, :, st*ii:ksz+st *
                                    ii, st*jj:ksz+st*jj].reshape(-1)
                    k += 1
        mat_list.append(mat)
        # For Shortcut Connection
        if i in sc_list:
            k = 0
            # 1x1 kernel, no padding: the unpadded activation is sampled at stride st.
            s = compute_conv_output_size(map_list[i], 1, stride_list[i])
            mat = np.zeros((1*1*in_channel[i], s*s*bsz))
            act = act_list[i].detach().cpu().numpy()
            for kk in range(bsz):
                for ii in range(s):
                    for jj in range(s):
                        mat[:, k] = act[kk, :, st*ii:1+st *
                                        ii, st*jj:1+st*jj].reshape(-1)
                        k += 1
            mat_sc_list.append(mat)

    # Interleave: each shortcut matrix (built at i = 5, 9, 13) is placed right
    # after the matrix at position i + 1, i.e. after positions 6, 10 and 14.
    ik = 0
    for i in range(len(mat_list)):
        mat_final.append(mat_list[i])
        if i in [6, 10, 14]:
            mat_final.append(mat_sc_list[ik])
            ik += 1

    # Record a flattened copy of every final matrix for this task.
    for i in range(len(mat_final)):
        old_task_distribution[task_id][i].append(
            deepcopy(mat_final[i].flatten()))

    print('-'*30)
    print('Representation Matrix')
    print('-'*30)
    for i in range(len(mat_final)):
        print('Layer {} : {}'.format(i+1, mat_final[i].shape))
    print('-'*30)
    return mat_final

def get_representation_matrix_alexnet(task_id, net, device, x, y, old_task_distribution):
    """Forward a label-ordered sample through ``net`` and return per-layer matrices.

    Thirteen sampling rounds are run.  Each round scans forward through ``y``
    and picks one index per unique label; the function asserts that exactly
    ten indices are picked per round.  The selected rows of ``x`` are passed
    through ``net`` after ``net.eval()`` (the network is left in eval mode).

    One matrix is then built per entry of ``net.map``:

    * layer indices 0-2: every column is one flattened ``ksz x ksz`` patch
      (all channels) taken at stride 1, without padding, from the recorded
      activation of that layer;
    * remaining indices: the first ``bsz`` rows of the recorded activation,
      transposed.

    Args:
        task_id: Forwarded to ``net`` and used as the first index into
            ``old_task_distribution``.
        net: Callable as ``net(data, task_id, None, -1)`` exposing ``eval()``,
            ``act``, ``map``, ``ksize`` and ``in_channel``.
        device: Device the sampled inputs are moved to.
        x: Input tensor, indexed along dim 0.
        y: Label tensor scanned to choose the sample indices.
        old_task_distribution: ``old_task_distribution[task_id][layer]`` must
            support ``append``; a flattened copy of each layer matrix is
            appended to it.

    Returns:
        The list of per-layer matrices, in layer order.
    """
    num_rounds = 13
    num_conv_layers = 3

    # The permutation itself is never read; the shuffle is kept because it
    # advances NumPy's global RNG state.
    shuffled = np.arange(x.size(0))
    np.random.shuffle(shuffled)

    labels = torch.unique(y)
    cursor = 0  # shared across rounds: the scan through y never rewinds
    sampled = []
    for _ in range(num_rounds):
        picked = []
        for label in labels:
            while y[cursor] != label:
                cursor += 1
            picked.append(cursor)
        assert len(picked) == 10
        sampled.append(x[picked].to(device))

    batch = torch.cat(sampled, dim=0)
    net.eval()
    # The return value is ignored; only net.act is consulted below.
    net(batch, task_id, None, -1)

    samples_per_layer = [2 * 12, 100, 100, 125, 125]
    mat_list = []
    layer_keys = list(net.act.keys())
    for layer in range(len(net.map)):
        n_samples = samples_per_layer[layer]
        if layer < num_conv_layers:
            ksz = net.ksize[layer]
            out_size = compute_conv_output_size(net.map[layer], ksz)
            layer_mat = np.zeros(
                (ksz * ksz * net.in_channel[layer],
                 out_size * out_size * n_samples))
            recorded = net.act[layer_keys[layer]].detach().cpu().numpy()
            col = 0  # running column index into layer_mat
            for sample in range(n_samples):
                for top in range(out_size):
                    for left in range(out_size):
                        patch = recorded[sample, :,
                                         top:top + ksz, left:left + ksz]
                        layer_mat[:, col] = patch.reshape(-1)
                        col += 1
        else:
            recorded = net.act[layer_keys[layer]].detach().cpu().numpy()
            layer_mat = recorded[0:n_samples].transpose()
        mat_list.append(layer_mat)
        old_task_distribution[task_id][layer].append(
            deepcopy(layer_mat.flatten()))

    rule = '-' * 30
    print(rule)
    print('Representation Matrix')
    print(rule)
    for number, layer_mat in enumerate(mat_list, start=1):
        print('Layer {} : {}'.format(number, layer_mat.shape))
    print(rule)
    return mat_list

def get_representation_matrix_lenet(task_id, net, device, x, y, old_task_distribution):
    """Forward a label-ordered sample through ``net`` and return per-layer matrices.

    ``net.eval()`` is called first (the network is left in eval mode).
    Twenty-five sampling rounds follow; each scans forward through ``y`` and
    picks one index per unique label, and the function asserts that exactly
    five indices are picked per round.  The selected rows of ``x`` are then
    passed through ``net``.

    One matrix is built per entry of ``net.map``:

    * layer indices 0-1: the recorded activation is zero-padded by 2 on each
      side of its last two dims, and every column is one flattened
      ``ksz x ksz`` patch (all channels) taken at stride 1;
    * remaining indices: the first ``bsz`` rows of the recorded activation,
      transposed.

    Args:
        task_id: Forwarded to ``net`` and used as the first index into
            ``old_task_distribution``.
        net: Callable as ``net(data, task_id, None, -1)`` exposing ``eval()``,
            ``act``, ``map``, ``ksize`` and ``in_channel``.
        device: Device the sampled inputs are moved to.
        x: Input tensor, indexed along dim 0.
        y: Label tensor scanned to choose the sample indices.
        old_task_distribution: ``old_task_distribution[task_id][layer]`` must
            support ``append``; a flattened copy of each layer matrix is
            appended to it.

    Returns:
        The list of per-layer matrices, in layer order.
    """
    num_rounds = 25
    num_conv_layers = 2

    net.eval()

    # The permutation itself is never read; the shuffle is kept because it
    # advances NumPy's global RNG state.
    shuffled = np.arange(x.size(0))
    np.random.shuffle(shuffled)

    labels = torch.unique(y)
    cursor = 0  # shared across rounds: the scan through y never rewinds
    sampled = []
    for _ in range(num_rounds):
        picked = []
        for label in labels:
            while y[cursor] != label:
                cursor += 1
            picked.append(cursor)
        assert len(picked) == 5
        sampled.append(x[picked].to(device))

    batch = torch.cat(sampled, dim=0).to(device)
    # The return value is ignored; only net.act is consulted below.
    net(batch, task_id, None, -1)

    samples_per_layer = [2*12, 100, 125, 125]
    zero_pad = 2
    pad_spec = (2, 2, 2, 2)  # F.pad spec matching zero_pad on both spatial dims
    mat_list = []
    layer_keys = list(net.act.keys())
    for layer in range(len(net.map)):
        n_samples = samples_per_layer[layer]
        if layer < num_conv_layers:
            ksz = net.ksize[layer]
            out_size = compute_conv_output_size(
                net.map[layer], ksz, 1, zero_pad)
            layer_mat = np.zeros(
                (ksz * ksz * net.in_channel[layer],
                 out_size * out_size * n_samples))
            padded = F.pad(net.act[layer_keys[layer]], pad_spec, "constant", 0)
            recorded = padded.detach().cpu().numpy()
            col = 0  # running column index into layer_mat
            for sample in range(n_samples):
                for top in range(out_size):
                    for left in range(out_size):
                        patch = recorded[sample, :,
                                         top:top + ksz, left:left + ksz]
                        layer_mat[:, col] = patch.reshape(-1)
                        col += 1
        else:
            recorded = net.act[layer_keys[layer]].detach().cpu().numpy()
            layer_mat = recorded[0:n_samples].transpose()
        mat_list.append(layer_mat)
        old_task_distribution[task_id][layer].append(
            deepcopy(layer_mat.flatten()))

    rule = '-'*30
    print(rule)
    print('Representation Matrix')
    print(rule)
    for number, layer_mat in enumerate(mat_list, start=1):
        print('Layer {} : {}'.format(number, layer_mat.shape))
    print(rule)
    return mat_list

